import torch
from torch import nn
# Apply a linear "gate" over the last axis of a 4-D random feature tensor, then
# derive the output's shape with its leading dimension replaced by 1.
mask_feature = torch.rand(size=(10, 256, 112, 112))

# nn.Linear operates on the LAST dimension of its input, so in_features must
# equal mask_feature.shape[-1] (112 here). The previous nn.Linear(64, 64) failed
# with "mat1 and mat2 shapes cannot be multiplied". The output width of 64 is
# kept unchanged.
gate = nn.Linear(mask_feature.shape[-1], 64)

# Shape: (10, 256, 112, 64) -- only the last axis is projected.
mask_feature2 = gate(mask_feature)

# torch.Size is an immutable tuple subclass, so item assignment (`sh[0] = 1`)
# raises TypeError. Build a new torch.Size with the leading dimension set to 1
# (presumably the batch axis -- TODO confirm intent).
sh = torch.Size([1, *mask_feature2.shape[1:]])